{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G3MMAcssHTML"
      },
      "source": [
        "<link rel=\"stylesheet\" href=\"/site-assets/css/gemma.css\">\n",
        "<link rel=\"stylesheet\" href=\"https://fonts.googleapis.com/css2?family=Google+Symbols:opsz,wght,FILL,GRAD@20..48,100..700,0..1,-50..200\" />"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2025 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aQXQaW_hv5RT"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://ai.google.dev/gemma/docs/core/pytorch_gemma\"><img src=\"https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png\" height=\"32\" width=\"32\" />View on ai.google.dev</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/core/pytorch_gemma.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/core/pytorch_gemma.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PXNm5_p_oxMF"
      },
      "source": [
        "# Run Gemma using PyTorch\n",
        "\n",
        "This guide shows you how to run Gemma using the PyTorch framework, including how\n",
        "to use image data for prompting Gemma release 3 and later models. For more\n",
        "details on the Gemma PyTorch implementation, see the project repository\n",
        "[README](https://github.com/google/gemma_pytorch)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jbza6uQdA-0P"
      },
      "source": [
        "## Setup\n",
        "\n",
        "The following sections explain how to set up your development environment,\n",
        "including how to get access to Gemma models on Kaggle, set authentication\n",
        "variables, install dependencies, and import packages."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5GpF_jcuhfmK"
      },
      "source": [
        "### System requirements\n",
        "\n",
        "The Gemma PyTorch library requires GPU or TPU processors to run the Gemma\n",
        "model. The standard Colab CPU Python runtime and T4 GPU Python runtime are\n",
        "sufficient for running Gemma 1B, 2B, and 4B size models. For advanced use cases\n",
        "on other GPUs or TPUs, see the\n",
        "[README](https://github.com/google/gemma_pytorch/blob/main/README.md) in the\n",
        "Gemma PyTorch repo."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-O5qzCnRhoi2"
      },
      "source": [
        "### Get access to Gemma on Kaggle\n",
        "\n",
        "To complete this tutorial, you first need to follow the setup instructions at\n",
        "[Gemma setup](https://ai.google.dev/gemma/docs/setup), which show you how to do\n",
        "the following:\n",
        "\n",
        "* Get access to Gemma on [Kaggle](https://www.kaggle.com/models/google/gemma/).\n",
        "* Select a Colab runtime with sufficient resources to run the Gemma model.\n",
        "* Generate and configure a Kaggle username and API key.\n",
        "\n",
        "After you've completed the Gemma setup, move on to the next section, where\n",
        "you'll set environment variables for your Colab environment."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ahatg9AKhkw8"
      },
      "source": [
        "### Set environment variables\n",
        "\n",
        "Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`. When prompted\n",
        "with the \"Grant access?\" messages, agree to provide secret access."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0qu4_r3PycgW"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from google.colab import userdata # `userdata` is a Colab API.\n",
        "\n",
        "os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
        "os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')"
      ]
    },
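    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If you run this notebook outside Colab, where the `userdata` API is typically\n",
        "unavailable, you need to set the two variables yourself. The following is a\n",
        "minimal sketch, not part of the official setup, that reports which variables are\n",
        "still missing. The helper name `missing_kaggle_credentials` is hypothetical.\n",
        "\n",
        "```python\n",
        "import os\n",
        "\n",
        "def missing_kaggle_credentials(env=None):\n",
        "    \"\"\"Return the names of required Kaggle variables that are unset or empty.\"\"\"\n",
        "    env = os.environ if env is None else env\n",
        "    required = ('KAGGLE_USERNAME', 'KAGGLE_KEY')\n",
        "    return [name for name in required if not env.get(name)]\n",
        "\n",
        "missing = missing_kaggle_credentials()\n",
        "if missing:\n",
        "    print('Missing Kaggle credentials:', ', '.join(missing))\n",
        "```\n",
        "\n",
        "An empty list means both variables are set and the download step can proceed."
      ]
    },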
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Fqq3HDVfA6Xm"
      },
      "source": [
        "### Install dependencies"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bMboT70Xop8G"
      },
      "outputs": [],
      "source": [
        "!pip install -q -U torch immutabledict sentencepiece"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ENdjDV3nBG5Z"
      },
      "source": [
        "### Download model weights"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GU5ZZzcZ6ik3"
      },
      "outputs": [],
      "source": [
        "# Choose variant and machine type\n",
        "VARIANT = '4b-it'  #@param ['1b','1b-it','4b','4b-it','12b','12b-it','27b','27b-it']\n",
        "MACHINE_TYPE = 'cuda' #@param ['cuda', 'cpu']\n",
        "CONFIG = VARIANT.split('-')[0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ONRhkIDrE4Un"
      },
      "outputs": [],
      "source": [
        "import kagglehub\n",
        "\n",
        "# Load model weights\n",
        "weights_dir = kagglehub.model_download(f'google/gemma-3/pyTorch/gemma-3-{VARIANT}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rTlLzq9FgtrH"
      },
      "source": [
        "Set the tokenizer and checkpoint paths for the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "viESUwjq5cAz"
      },
      "outputs": [],
      "source": [
        "# Ensure that the tokenizer is present\n",
        "tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')\n",
        "assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'\n",
        "\n",
        "# Ensure that the checkpoint is present\n",
        "ckpt_path = os.path.join(weights_dir, 'model.ckpt')\n",
        "assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hOft88e7BOBB"
      },
      "source": [
        "## Configure the run environment\n",
        "\n",
        "The following sections explain how to prepare a PyTorch environment for running\n",
        "Gemma."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bcfba5XAiR4q"
      },
      "source": [
        "### Prepare the PyTorch run environment\n",
        "\n",
        "Prepare the PyTorch model execution environment by cloning the Gemma PyTorch\n",
        "repository."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ww83zI9ToPso"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Cloning into 'gemma_pytorch'...\n",
            "remote: Enumerating objects: 239, done.\u001b[K\n",
            "remote: Counting objects: 100% (123/123), done.\u001b[K\n",
            "remote: Compressing objects: 100% (68/68), done.\u001b[K\n",
            "remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116\u001b[K\n",
            "Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done.\n",
            "Resolving deltas: 100% (135/135), done.\n"
          ]
        }
      ],
      "source": [
        "!git clone https://github.com/google/gemma_pytorch.git"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sw-KBZ1vBSl3"
      },
      "outputs": [],
      "source": [
        "import sys\n",
        "\n",
        "sys.path.append('gemma_pytorch/gemma')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XFUXlF74BTNe"
      },
      "outputs": [],
      "source": [
        "from gemma_pytorch.gemma.config import get_model_config\n",
        "from gemma_pytorch.gemma.gemma3_model import Gemma3ForMultimodalLM\n",
        "\n",
        "import os\n",
        "import torch"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-9PvhVSYBWBt"
      },
      "source": [
        "### Set the model configuration\n",
        "\n",
        "Before you run the model, you must set some configuration parameters, including\n",
        "the Gemma variant, the data type, and the tokenizer path."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e2olXB1b45Hz"
      },
      "outputs": [],
      "source": [
        "# Set up model config.\n",
        "model_config = get_model_config(CONFIG)\n",
        "model_config.dtype = \"float32\" if MACHINE_TYPE == \"cpu\" else \"float16\"\n",
        "model_config.tokenizer = tokenizer_path"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zSp7oLLwjKCa"
      },
      "source": [
        "### Configure the device context\n",
        "\n",
        "The following code configures the device context for running the model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "596KAf_zjdi5"
      },
      "outputs": [],
      "source": [
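        "import contextlib\n",
        "\n",
        "@contextlib.contextmanager\n",
        "def _set_default_tensor_type(dtype: torch.dtype):\n",
        "    \"\"\"Temporarily sets the default torch dtype to the given dtype.\"\"\"\n",
        "    torch.set_default_dtype(dtype)\n",
        "    try:\n",
        "        yield\n",
        "    finally:\n",
        "        # Restore the float32 default even if model loading raises.\n",
        "        torch.set_default_dtype(torch.float)"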
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rtIQRpbdfyhj"
      },
      "source": [
        "### Instantiate and load the model\n",
        "\n",
        "Load the model with its weights to prepare to run requests."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oQEO8Uf-evZm"
      },
      "outputs": [],
      "source": [
        "device = torch.device(MACHINE_TYPE)\n",
        "with _set_default_tensor_type(model_config.get_dtype()):\n",
        "    model = Gemma3ForMultimodalLM(model_config)\n",
        "    model.load_state_dict(torch.load(ckpt_path)['model_state_dict'])\n",
        "    model = model.to(device).eval()\n",
        "print(\"Model loading done.\")\n",
        "\n",
        "print('Generating requests in chat mode...')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "738CGmN-BocU"
      },
      "source": [
        "## Run inference\n",
        "\n",
        "Below are examples for generating in chat mode and generating with multiple\n",
        "requests.\n",
        "\n",
        "The instruction-tuned Gemma models were trained with a specific formatter that\n",
        "annotates instruction tuning examples with extra information, both during\n",
        "training and inference. The annotations (1) indicate roles in a conversation,\n",
        "and (2) delineate turns in a conversation.\n",
        "\n",
        "The relevant annotation tokens are:\n",
        "\n",
        "- `user`: user turn\n",
        "- `model`: model turn\n",
        "- `<start_of_turn>`: beginning of dialog turn\n",
        "- `<start_of_image>`: tag for image data input\n",
        "- `<end_of_turn><eos>`: end of dialog turn\n",
        "\n",
        "For more information, read about prompt formatting for instruction tuned Gemma\n",
        "models [here](https://ai.google.dev/gemma/core/prompt-structure).\n"
      ]
    },
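    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The annotation tokens listed above can be assembled programmatically. The\n",
        "following is a minimal sketch, assuming the turn layout described in this\n",
        "section. The helper name `build_chat_prompt` is hypothetical and is not part of\n",
        "the Gemma PyTorch library.\n",
        "\n",
        "```python\n",
        "def build_chat_prompt(turns):\n",
        "    \"\"\"Format (role, text) turns and open a final model turn for generation.\"\"\"\n",
        "    allowed_roles = ('user', 'model')\n",
        "    parts = []\n",
        "    for role, text in turns:\n",
        "        if role not in allowed_roles:\n",
        "            raise ValueError(f'Unknown role: {role!r}')\n",
        "        parts.append(f'<start_of_turn>{role}\\n{text}<end_of_turn><eos>\\n')\n",
        "    # Leave the last model turn open so the model writes the reply.\n",
        "    parts.append('<start_of_turn>model\\n')\n",
        "    return ''.join(parts)\n",
        "\n",
        "prompt = build_chat_prompt([\n",
        "    ('user', 'What is a good place for travel in the US?'),\n",
        "    ('model', 'California.'),\n",
        "    ('user', 'What can I do in California?'),\n",
        "])\n",
        "print(prompt)\n",
        "```\n",
        "\n",
        "The returned string ends with an open model turn, so it can be passed to the\n",
        "model without further wrapping."
      ]
    },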
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iYNlBv-ak3G2"
      },
      "source": [
        "### Generate text with text\n",
        "\n",
        "The following is a sample code snippet demonstrating how to format a prompt for\n",
        "an instruction-tuned Gemma model using user and model chat templates in a\n",
        "multi-turn conversation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yygIK9DEIldp"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Chat prompt:\n",
            " <start_of_turn>user\n",
            "What is a good place for travel in the US?<end_of_turn><eos>\n",
            "<start_of_turn>model\n",
            "California.<end_of_turn><eos>\n",
            "<start_of_turn>user\n",
            "What can I do in California?<end_of_turn><eos>\n",
            "<start_of_turn>model\n",
            "\n"
          ]
        },
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "\"California is a state brimming with diverse activities! To give you a great list, tell me: \\n\\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \\n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \\n* **What's your budget like?** \\n* **Who are you traveling with?** (family, friends, solo)  \\n\\nThe more you tell me, the better recommendations I can give! 😊  \\n<end_of_turn>\""
            ]
          },
          "execution_count": 14,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Chat templates\n",
        "USER_CHAT_TEMPLATE = \"<start_of_turn>user\\n{prompt}<end_of_turn><eos>\\n\"\n",
        "MODEL_CHAT_TEMPLATE = \"<start_of_turn>model\\n{prompt}<end_of_turn><eos>\\n\"\n",
        "\n",
        "# Sample formatted prompt\n",
        "prompt = (\n",
        "    USER_CHAT_TEMPLATE.format(\n",
        "        prompt='What is a good place for travel in the US?'\n",
        "    )\n",
        "    + MODEL_CHAT_TEMPLATE.format(prompt='California.')\n",
        "    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')\n",
        "    + '<start_of_turn>model\\n'\n",
        ")\n",
        "print('Chat prompt:\\n', prompt)\n",
        "\n",
        "# The prompt is already formatted, so pass it to the model as-is.\n",
        "model.generate(\n",
        "    prompt,\n",
        "    device=device,\n",
        "    output_len=256,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oP746yI9PirY"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "\"\\n\\nA swirling cloud of data, raw and bold,\\nIt hums and whispers, a story untold.\\nAn LLM whispers, code into refrain,\\nCrafting words of rhyme, a lyrical strain.\\n\\nA world of pixels, logic's vibrant hue,\\nFlows through its veins, forever anew.\\nThe human touch it seeks, a gentle hand,\\nTo mold and shape, understand.\\n\\nEmotions it might learn, from snippets of prose,\\nInspiration it seeks, a yearning\""
            ]
          },
          "execution_count": 13,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Generate sample\n",
        "model.generate(\n",
        "    'Write a poem about an llm writing a poem.',\n",
        "    device=device,\n",
        "    output_len=100,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2xprL7RalqVe"
      },
      "source": [
        "### Generate text with images\n",
        "\n",
        "With Gemma release 3 and later, you can use images with your prompt. The\n",
        "following example shows you how to include visual data with your prompt."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NDWuxqatl-AA"
      },
      "outputs": [],
      "source": [
        "print('Chat with images...\\n')\n",
        "\n",
        "def read_image(url):\n",
        "    import io\n",
        "    import requests\n",
        "    from PIL import Image\n",
        "\n",
        "    # Download the image bytes and decode them with Pillow.\n",
        "    contents = io.BytesIO(requests.get(url).content)\n",
        "    return Image.open(contents)\n",
        "\n",
        "image = read_image(\n",
        "    'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'\n",
        ")\n",
        "\n",
        "print(model.generate(\n",
        "    [\n",
        "        [\n",
        "            '<start_of_turn>user\\n',\n",
        "            image,\n",
        "            'What animal is in this image?<end_of_turn>\\n',\n",
        "            '<start_of_turn>model\\n'\n",
        "        ]\n",
        "    ],\n",
        "    device=device,\n",
        "    output_len=256,\n",
        "))"
      ]
    },
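    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As far as the example above shows, an image prompt is a nested list in which\n",
        "each inner list interleaves text segments with image objects. The following is a\n",
        "minimal sketch, assuming that format, of assembling one such request. The helper\n",
        "name `build_image_turn` is hypothetical and is not part of the Gemma PyTorch\n",
        "library.\n",
        "\n",
        "```python\n",
        "def build_image_turn(image, question):\n",
        "    \"\"\"Return one interleaved user turn: text, image object, text.\"\"\"\n",
        "    return [\n",
        "        '<start_of_turn>user\\n',\n",
        "        image,\n",
        "        f'{question}<end_of_turn>\\n',\n",
        "        '<start_of_turn>model\\n',\n",
        "    ]\n",
        "\n",
        "# Any object stands in for the image here; the cell above passes a PIL image.\n",
        "placeholder_image = object()\n",
        "batch = [build_image_turn(placeholder_image, 'What animal is in this image?')]\n",
        "print(len(batch), len(batch[0]))\n",
        "```\n",
        "\n",
        "With a real image in place of the placeholder, `batch` has the same shape as the\n",
        "first argument of `model.generate` in the example above."
      ]
    },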
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IF7B-3UJHMPd"
      },
      "source": [
        "## Learn more\n",
        "\n",
        "Now that you have learned how to use Gemma in PyTorch, you can explore the many\n",
        "other things that Gemma can do at\n",
        "[ai.google.dev/gemma](https://ai.google.dev/gemma).\n",
        "\n",
        "See also these other related resources:\n",
        "\n",
        "- [Gemma core models overview](https://ai.google.dev/gemma/docs/core)\n",
        "- [Gemma C++ Tutorial](https://ai.google.dev/gemma/docs/core/gemma_cpp)\n",
        "- [Gemma prompt and system instructions](https://ai.google.dev/gemma/core/prompt-structure)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "pytorch_gemma.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
